#!/usr/bin/env python3

import subprocess
import socket
import sys
import os

from ctypes import CDLL

# Handle to the C library; nsenter() calls setns() through it
libc = CDLL('libc.so.6')

# Socket option number used to set a mark on a socket
SO_MARK = 36 # Define socket option for python2 compatibility
# Mark value that is expected to pass while the network is locked
# (the --pre-restore phase checks that only marked traffic gets through)
SOCCR_MARK = 0xC114
# Namespace type passed to setns() by nsenter()
CLONE_NEWNET = 0x40000000
# TCP port the test server binds to on 127.0.0.1
PORT = 8880
# Name of the network namespace created by --post-start ("ip netns add")
NETNS = "criu-net-lock-test"
# Default socket timeout (seconds) installed by the client phases
TIMEOUT = 0.1
# File created by --post-start and removed by --clean
SYNCFILE = "zdtm/static/net_lock.sync"

def nsenter():
    """Move the calling process into the test network namespace.

    Opens /var/run/netns/<NETNS> and passes its file descriptor to
    setns() with CLONE_NEWNET.

    Raises:
        IOError/OSError: if the netns file cannot be opened.
        OSError: if setns() returns a non-zero value. Without this check
            a failed switch would go unnoticed and the caller would keep
            running in its original network namespace.
    """
    netns_path = "/var/run/netns/{}".format(NETNS)
    with open(netns_path) as f:
        if libc.setns(f.fileno(), CLONE_NEWNET) != 0:
            raise OSError("setns() failed for '{}'".format(netns_path))

def create_sync_file():
    """Create SYNCFILE as an empty file (truncating it if it exists)."""
    with open(SYNCFILE, "wb"):
        # Nothing to write: only the existence of the file matters
        pass

if sys.argv[1] == "--post-start":
    # Add test netns
    subprocess.Popen(["ip", "netns", "add", NETNS]).wait()
    nsenter() # Enter test netns
    subprocess.Popen(["ip", "link", "set", "up", "dev", "lo"]).wait()

    # Lets test know that the netns is initialized successfully
    # by checking the access of SYNCFILE
    create_sync_file()

    # Bind and listen BEFORE forking. The forked child inherits the
    # listening socket, so the server is guaranteed to be reachable by
    # the time this hook returns; creating the socket in the child would
    # let a following --pre-dump client race against bind()/listen().
    # A bind failure is now also reported by the hook itself.
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind(("127.0.0.1", PORT))
    srv.listen(1)

    # TCP server daemon
    pid = os.fork()
    if pid == 0:
        os.setsid() # Detach from parent

        # We should accept connections normally from
        # pre-dump client, post-restore client and
        # the pre-restore client using SOCCR_MARK
        # The pre-restore client without SOCCR_MARK
        # should fail with a timeout
        #
        # Each entry is (reply sent to the client, mark set on the
        # listening socket while accepting; None leaves it untouched)
        phases = (
            ("--pre-dump", None),
            ("--pre-restore", SOCCR_MARK),
            ("--post-restore", None),
        )

        # Loop allows zdtm multiple iterations to work (i.e. --iter 3)
        while True:
            for reply, mark in phases:
                if mark is not None:
                    srv.setsockopt(socket.SOL_SOCKET, SO_MARK, mark)

                cln, addr = srv.accept()
                try:
                    cln.sendall(reply.encode())
                finally:
                    cln.close()

                if mark is not None:
                    # Drop the mark again for the following clients
                    srv.setsockopt(socket.SOL_SOCKET, SO_MARK, 0)

        # Server will be closed when zdtm sends SIGKILL
    else:
        # Parent: the child owns the listening socket from here on;
        # closing our copy does not affect the child's descriptor
        srv.close()

if sys.argv[1] == "--pre-dump":
    # The network has not been locked at this point, so a plain client
    # is expected to reach the server and read the phase name back

    nsenter() # Enter test netns

    socket.setdefaulttimeout(TIMEOUT)
    client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    status = 0
    try:
        client.connect(("127.0.0.1", PORT))
        reply = client.recv(100).decode()
    except socket.timeout:
        print("Timeout when trying to connect to server")
        status = 1
    else:
        # The server answers with the name of the phase it expects
        if reply != sys.argv[1]:
            print("Expected '{}', got '{}'".format(sys.argv[1], reply))
            status = 1
    finally:
        client.close()

    sys.exit(status)

if sys.argv[1] == "--pre-restore":
    # Network should be locked now
    # Only packets with SOCCR_MARK should be allowed
    # Client should timeout when connecting without SOCCR_MARK

    nsenter() # Enter test netns

    socket.setdefaulttimeout(TIMEOUT)

    # SOCCR_MARK: this client must get through and read the phase name.
    # The socket is created outside the try block so that the finally
    # clause always has a valid socket to close.
    cln = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        cln.setsockopt(socket.SOL_SOCKET, SO_MARK, SOCCR_MARK)
        # Same literal address the server binds to; avoids depending on
        # name resolution of "localhost" inside the test netns
        cln.connect(("127.0.0.1", PORT))
        resp = cln.recv(100).decode()
        if resp != sys.argv[1]:
            print("Expected '{}', got '{}'".format(sys.argv[1], resp))
            sys.exit(1)
    except socket.timeout:
        print("Timeout when trying to connect to server")
        sys.exit(1)
    finally:
        cln.close()

    # No SOCCR_MARK: this client must NOT get through, so a timeout is
    # the successful outcome here
    cln = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        cln.connect(("127.0.0.1", PORT))
    except socket.timeout:
        sys.exit(0)
    else:
        print("Network should be locked, but client connected normally")
        sys.exit(1)
    finally:
        cln.close()

if sys.argv[1] == "--post-restore":
    # Network should be unlocked now
    # Client should be able to connect normally

    nsenter() # Enter test netns

    socket.setdefaulttimeout(TIMEOUT)
    cln = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Same literal address the server binds to (and that --pre-dump
        # uses); avoids depending on name resolution of "localhost"
        cln.connect(("127.0.0.1", PORT))
        resp = cln.recv(100).decode()
        # The server answers with the name of the phase it expects
        if resp != sys.argv[1]:
            print("Expected '{}', got '{}'".format(sys.argv[1], resp))
            sys.exit(1)
    except socket.timeout:
        print("Timeout when trying to connect to server")
        sys.exit(1)
    else:
        sys.exit(0)
    finally:
        cln.close()

if sys.argv[1] == "--clean":
    # Tear down what --post-start set up: first the test netns ...
    subprocess.call(["ip", "netns", "delete", NETNS])
    # ... then the sync file
    os.remove(SYNCFILE)
